# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Transformers for Google GenAI SDK."""

import base64
from collections.abc import Iterable, Mapping
from enum import Enum, EnumMeta
import inspect
import io
import logging
import re
import sys
import time
import types as builtin_types
import typing
from typing import Any, GenericAlias, List, Optional, Sequence, Union  # type: ignore[attr-defined]
from ._mcp_utils import mcp_to_gemini_tool
from ._common import get_value_by_path as getv

if typing.TYPE_CHECKING:
  import PIL.Image

import pydantic

from . import _api_client
from . import _common
from . import types

logger = logging.getLogger('google_genai._transformers')

if sys.version_info >= (3, 10):
  VersionedUnionType = builtin_types.UnionType
  _UNION_TYPES = (typing.Union, builtin_types.UnionType)
  from typing import TypeGuard
else:
  VersionedUnionType = typing._UnionGenericAlias  # type: ignore[attr-defined]
  _UNION_TYPES = (typing.Union,)
  from typing_extensions import TypeGuard

if typing.TYPE_CHECKING:
  from mcp import ClientSession as McpClientSession
  from mcp.types import Tool as McpTool
else:
  McpClientSession: typing.Type = Any
  McpTool: typing.Type = Any
  try:
    from mcp import ClientSession as McpClientSession
    from mcp.types import Tool as McpTool
  except ImportError:
    McpClientSession = None
    McpTool = None


# Maps SDK-style (snake_case) metric names to the API's spec field names.
metric_name_sdk_api_map = {
    'exact_match': 'exactMatchSpec',
    'bleu': 'bleuSpec',
    'rouge_spec': 'rougeSpec',
}
# Inverse mapping: API spec field name -> SDK metric name.
metric_name_api_sdk_map = {v: k for k, v in metric_name_sdk_api_map.items()}


def _is_duck_type_of(obj: Any, cls: type[pydantic.BaseModel]) -> bool:
  """Checks if an object has all of the fields of a Pydantic model.

  This is a duck-typing alternative to `isinstance` to solve dual-import
  problems. It returns False for dictionaries, which should be handled by
  `isinstance(obj, dict)`.

  Args:
    obj: The object to check.
    cls: The Pydantic model class to duck-type against.

  Returns:
    True if the object has all the fields defined in the Pydantic model, False
    otherwise.
  """
  # Dicts are deliberately excluded, and `cls` must actually look like a
  # pydantic model (have `model_fields`).
  if isinstance(obj, dict) or not hasattr(cls, 'model_fields'):
    return False

  # Check if the object has all of the Pydantic model's defined fields.
  all_matched = all(hasattr(obj, field) for field in cls.model_fields)
  if not all_matched and isinstance(obj, pydantic.BaseModel):
    # Check the other way around if obj is a Pydantic model.
    # Check if the Pydantic model has all of the object's defined fields.
    try:
      # cls() raises when `cls` has required fields; treat that as no match.
      obj_private = cls()
      all_matched = all(hasattr(obj_private, f) for f in type(obj).model_fields)
    except ValueError:
      return False
  return all_matched


def _resource_name(
    client: _api_client.BaseApiClient,
    resource_name: str,
    *,
    collection_identifier: str,
    collection_hierarchy_depth: int = 2,
) -> str:
  # pylint: disable=line-too-long
  """Completes a resource name with project/location/collection prefixes.

  The collection identifier is prepended only when it is not already present
  and doing so would produce exactly `collection_hierarchy_depth` path
  segments. On Vertex AI the name is additionally qualified with the client's
  project and location unless it already starts with `projects/`. Names that
  match none of the prepending conditions are returned unchanged.

  Args:
    client: The API client.
    resource_name: The user input resource name to be completed.
    collection_identifier: The collection identifier to be prepended. See
      collection identifiers in https://google.aip.dev/122.
    collection_hierarchy_depth: The collection hierarchy depth. Only set this
      field when the resource has nested collections, e.g. for
      `users/vhugo1802/events/birthday-dinner-226` the collection_identifier
      is `users` and collection_hierarchy_depth is 4. See nested collections
      in https://google.aip.dev/122.

  Examples (Vertex, project='bar', location='us-west1',
  collection_identifier='cachedContents'):
    '123'                  -> 'projects/bar/locations/us-west1/cachedContents/123'
    'cachedContents/123'   -> 'projects/bar/locations/us-west1/cachedContents/123'
    'projects/foo/...'     -> unchanged
  Examples (Gemini API):
    '123'                  -> 'cachedContents/123'
    'some/wrong/.../123'   -> unchanged (would violate hierarchy depth)

  Returns:
    The completed resource name.
  """
  already_prefixed = resource_name.startswith(f'{collection_identifier}/')
  prepended = f'{collection_identifier}/{resource_name}'
  # Prepending is only allowed when the result has exactly the expected
  # number of path segments for this collection.
  fits_hierarchy = prepended.count('/') + 1 == collection_hierarchy_depth
  should_prepend = not already_prefixed and fits_hierarchy

  if not client.vertexai:
    return prepended if should_prepend else resource_name

  if resource_name.startswith('projects/'):
    return resource_name
  if resource_name.startswith('locations/'):
    return f'projects/{client.project}/{resource_name}'
  if already_prefixed:
    return f'projects/{client.project}/locations/{client.location}/{resource_name}'
  if should_prepend:
    return f'projects/{client.project}/locations/{client.location}/{prepended}'
  return resource_name


def t_model(client: _api_client.BaseApiClient, model: str) -> str:
  """Normalizes a user-supplied model name into a full resource name.

  Vertex AI names expand to `publishers/{publisher}/models/{id}`; Gemini API
  names expand to `models/{model}`. Names already carrying a recognized
  prefix pass through unchanged.

  Args:
    client: The API client; selects Vertex vs. Gemini naming rules.
    model: The user-supplied model name.

  Returns:
    The normalized model resource name.

  Raises:
    ValueError: If `model` is empty or contains disallowed characters.
  """
  if not model:
    raise ValueError('model is required.')
  # Reject path-traversal and query-string style characters early.
  if any(token in model for token in ('..', '?', '&')):
    raise ValueError('invalid model parameter.')

  if not client.vertexai:
    if model.startswith(('models/', 'tunedModels/')):
      return model
    return f'models/{model}'

  if model.startswith(('projects/', 'models/', 'publishers/')):
    return model
  if '/' in model:
    publisher, model_id = model.split('/', 1)
    return f'publishers/{publisher}/models/{model_id}'
  return f'publishers/google/models/{model}'


def t_models_url(
    api_client: _api_client.BaseApiClient, base_models: bool
) -> str:
  """Returns the models listing URL path for the target API surface.

  Args:
    api_client: The API client; selects Vertex vs. Gemini paths.
    base_models: True for foundation models, False for tuned/custom models.
  """
  if api_client.vertexai:
    return 'publishers/google/models' if base_models else 'models'
  return 'models' if base_models else 'tunedModels'


def t_extract_models(
    response: _common.StringDict,
) -> list[_common.StringDict]:
  """Extracts the model list from a list-models response payload.

  Probes the known response keys in order and returns the first list found.

  Args:
    response: The raw list-models response dict.

  Returns:
    The list of model dicts, or an empty list when none is present.
  """
  if not response:
    return []

  for key in ('models', 'tunedModels', 'publisherModels'):
    models: Optional[list[_common.StringDict]] = response.get(key)
    if models is not None:
      return models

  # A headers-only response (httpHeaders present, no jsonPayload) is a
  # legitimate empty result; anything else is unexpected, so log it.
  headers_only = (
      response.get('httpHeaders') is not None
      and response.get('jsonPayload') is None
  )
  if not headers_only:
    logger.warning('Cannot determine the models type.')
    logger.debug('Cannot determine the models type for response: %s', response)
  return []


def t_caches_model(
    api_client: _api_client.BaseApiClient, model: str
) -> Optional[str]:
  """Normalizes a model name for use with the caches API.

  Vertex caches only accept fully qualified, project-scoped model names, so
  publisher-/model-prefixed names are expanded with the client's project and
  location.
  """
  model = t_model(api_client, model)
  if not model:
    return None
  if not api_client.vertexai:
    return model
  scope = f'projects/{api_client.project}/locations/{api_client.location}'
  if model.startswith('publishers/'):
    return f'{scope}/{model}'
  if model.startswith('models/'):
    return f'{scope}/publishers/google/{model}'
  return model


def pil_to_blob(img: Any) -> types.Blob:
  """Serializes a PIL image into a `types.Blob` as PNG or JPEG bytes."""
  PngImagePlugin: Optional[builtin_types.ModuleType]
  try:
    import PIL.PngImagePlugin

    PngImagePlugin = PIL.PngImagePlugin
  except ImportError:
    PngImagePlugin = None

  # Keep PNG sources and alpha-channel images as PNG (JPEG has no alpha);
  # everything else is encoded as JPEG.
  keep_png = (
      PngImagePlugin is not None
      and isinstance(img, PngImagePlugin.PngImageFile)
  ) or img.mode == 'RGBA'

  buffer = io.BytesIO()
  if keep_png:
    img.save(buffer, format='PNG')
    mime_type = 'image/png'
  else:
    img.save(buffer, format='JPEG')
    mime_type = 'image/jpeg'
  buffer.seek(0)
  return types.Blob(mime_type=mime_type, data=buffer.read())


def t_function_response(
    function_response: types.FunctionResponseOrDict,
) -> types.FunctionResponse:
  """Coerces a dict or FunctionResponse-like object to a FunctionResponse.

  Raises:
    ValueError: If the input is falsy.
    TypeError: If the input is neither a dict nor FunctionResponse-like.
  """
  if not function_response:
    raise ValueError('function_response is required.')
  if isinstance(function_response, dict):
    return types.FunctionResponse.model_validate(function_response)
  if _is_duck_type_of(function_response, types.FunctionResponse):
    return function_response
  raise TypeError(
      'Could not parse input as FunctionResponse. Unsupported'
      f' function_response type: {type(function_response)}'
  )


def t_function_responses(
    function_responses: Union[
        types.FunctionResponseOrDict,
        Sequence[types.FunctionResponseOrDict],
    ],
) -> list[types.FunctionResponse]:
  """Coerces one or a sequence of function responses into a list.

  Raises:
    ValueError: If the input is falsy.
  """
  if not function_responses:
    raise ValueError('function_responses are required.')
  if not isinstance(function_responses, Sequence):
    return [t_function_response(function_responses)]
  return [t_function_response(item) for item in function_responses]


def t_blobs(
    blobs: Union[types.BlobImageUnionDict, list[types.BlobImageUnionDict]],
) -> list[types.Blob]:
  """Coerces one blob-like value or a list of them to `list[types.Blob]`."""
  items = blobs if isinstance(blobs, list) else [blobs]
  return [t_blob(item) for item in items]


def t_blob(blob: types.BlobImageUnionDict) -> types.Blob:
  """Coerces a Blob-like object, dict, or PIL image into a `types.Blob`.

  Raises:
    ValueError: If `blob` is falsy.
    TypeError: If the input type is unsupported.
  """
  if not blob:
    raise ValueError('blob is required.')

  if _is_duck_type_of(blob, types.Blob):
    return blob  # type: ignore[return-value]

  if isinstance(blob, dict):
    return types.Blob.model_validate(blob)

  # Cheap class-name probe so PIL is only imported for image-like inputs.
  if 'image' in type(blob).__name__.lower():
    try:
      import PIL.Image

      pil_image_cls: Optional[type] = PIL.Image.Image
    except ImportError:
      pil_image_cls = None

    if pil_image_cls is not None and isinstance(blob, pil_image_cls):
      return pil_to_blob(blob)

  raise TypeError(
      f'Could not parse input as Blob. Unsupported blob type: {type(blob)}'
  )


def t_image_blob(blob: types.BlobImageUnionDict) -> types.Blob:
  """Coerces input to a Blob and verifies it carries an image mime type."""
  result = t_blob(blob)
  mime = result.mime_type
  if mime and mime.startswith('image/'):
    return result
  raise ValueError(f'Unsupported mime type: {mime!r}')


def t_audio_blob(blob: types.BlobOrDict) -> types.Blob:
  """Coerces input to a Blob and verifies it carries an audio mime type."""
  result = t_blob(blob)
  mime = result.mime_type
  if mime and mime.startswith('audio/'):
    return result
  raise ValueError(f'Unsupported mime type: {mime!r}')


def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
  """Coerces a string, File, dict, Part, or PIL image into a `types.Part`.

  Args:
    part: The value to convert.

  Returns:
    The equivalent `types.Part`.

  Raises:
    ValueError: If `part` is None, a File missing uri/mime_type, or an
      unsupported type.
  """
  if part is None:
    raise ValueError('content part is required.')
  if isinstance(part, str):
    return types.Part(text=part)
  # NOTE: File is checked before the dict/Part branches so File-like objects
  # become file-uri parts.
  if _is_duck_type_of(part, types.File):
    if not part.uri or not part.mime_type:  # type: ignore[union-attr]
      raise ValueError('file uri and mime_type are required.')
    return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)  # type: ignore[union-attr]
  if isinstance(part, dict):
    try:
      return types.Part.model_validate(part)
    except pydantic.ValidationError:
      # A dict that is not Part-shaped is treated as a FileData payload.
      return types.Part(file_data=types.FileData.model_validate(part))
  if _is_duck_type_of(part, types.Part):
    return part  # type: ignore[return-value]

  # Cheap class-name probe so PIL is only imported for image-like inputs.
  if 'image' in part.__class__.__name__.lower():
    try:
      import PIL.Image

      PIL_Image = PIL.Image.Image
    except ImportError:
      PIL_Image = None

    if PIL_Image is not None and isinstance(part, PIL_Image):
      return types.Part(inline_data=pil_to_blob(part))
  raise ValueError(f'Unsupported content part type: {type(part)}')


def t_parts(
    parts: Optional[
        Union[list[types.PartUnionDict], types.PartUnionDict, list[types.Part]]
    ],
) -> list[types.Part]:
  """Coerces one part-like value or a list of them to `list[types.Part]`.

  Raises:
    ValueError: If `parts` is None or an empty list.
  """
  if parts is None:
    raise ValueError('content parts are required.')
  if not isinstance(parts, list):
    return [t_part(parts)]
  if not parts:
    raise ValueError('content parts are required.')
  return [t_part(item) for item in parts]


def t_image_predictions(
    predictions: Optional[Iterable[Mapping[str, Any]]],
) -> Optional[list[types.GeneratedImage]]:
  """Converts raw image predictions into `types.GeneratedImage` objects.

  Args:
    predictions: Raw prediction mappings from the API response; entries
      without an 'image' value are skipped.

  Returns:
    The generated images, or None when `predictions` is falsy.
  """
  if not predictions:
    return None
  images = []
  for prediction in predictions:
    image = prediction.get('image')
    if image:
      images.append(
          types.GeneratedImage(
              image=types.Image(
                  # Use .get(): direct indexing raised KeyError when either
                  # key was absent (presumably only one of gcsUri/imageBytes
                  # is populated per image — TODO confirm against responses).
                  gcs_uri=image.get('gcsUri'),
                  image_bytes=image.get('imageBytes'),
              )
          )
      )
  return images


# Any value accepted where a single Content is expected: a Content model, its
# dict form, or a bare part (str, Part, File, PIL image, ...).
ContentType = Union[types.Content, types.ContentDict, types.PartUnionDict]


def t_content(
    content: Union[ContentType, types.ContentDict, None],
) -> types.Content:
  """Coerces a Content, dict, File, or bare part into a `types.Content`.

  Bare parts are wrapped in ModelContent when they carry a function_call,
  otherwise in UserContent.

  Raises:
    ValueError: If `content` is None.
  """
  if content is None:
    raise ValueError('content is required.')
  if _is_duck_type_of(content, types.Content):
    return content  # type: ignore[return-value]
  if isinstance(content, dict):
    try:
      return types.Content.model_validate(content)
    except pydantic.ValidationError:
      # Not Content-shaped: treat the dict as a single part instead.
      possible_part = t_part(content)  # type: ignore[arg-type]
      return (
          types.ModelContent(parts=[possible_part])
          if possible_part.function_call
          else types.UserContent(parts=[possible_part])
      )
  if _is_duck_type_of(content, types.File):
    return types.UserContent(parts=[t_part(content)])  # type: ignore[arg-type]
  if _is_duck_type_of(content, types.Part):
    return (
        types.ModelContent(parts=[content])  # type: ignore[arg-type]
        if content.function_call  # type: ignore[union-attr]
        else types.UserContent(parts=[content])  # type: ignore[arg-type]
    )
  # Fall through: let UserContent validation handle anything else.
  return types.UserContent(parts=content)  # type: ignore[arg-type]


def t_contents_for_embed(
    client: _api_client.BaseApiClient,
    contents: Union[list[types.Content], list[types.ContentDict], ContentType],
) -> Union[list[str], list[types.Content]]:
  """Normalizes contents for embedding requests.

  For Vertex AI only the text of each part is kept; for the Gemini API the
  transformed Content objects are returned as-is.
  """
  if isinstance(contents, list):
    transformed = [t_content(item) for item in contents]
  else:
    transformed = [t_content(contents)]

  if not client.vertexai:
    return transformed

  text_parts: list[str] = []
  for content in transformed:
    if content is None:
      continue
    if isinstance(content, dict):
      content = types.Content.model_validate(content)
    if content.parts is None:
      continue
    for part in content.parts:
      if part.text:
        text_parts.append(part.text)
      else:
        logger.warning('Non-text part found, only returning text parts.')
  return text_parts


def t_contents(
    contents: Optional[
        Union[types.ContentListUnion, types.ContentListUnionDict, types.Content]
    ],
) -> list[types.Content]:
  """Normalizes user-supplied contents into a list of `types.Content`.

  Consecutive bare parts are grouped by role: user parts into a UserContent,
  parts carrying a function_call into a ModelContent. Content objects, lists
  of parts, and Content dicts are appended directly.

  Raises:
    ValueError: If `contents` is None/empty or contains an unsupported type.
  """
  if contents is None or (isinstance(contents, list) and not contents):
    raise ValueError('contents are required.')
  if not isinstance(contents, list):
    return [t_content(contents)]

  result: list[types.Content] = []
  accumulated_parts: list[types.Part] = []

  def _is_part(
      part: Union[types.PartUnionDict, Any],
  ) -> TypeGuard[types.PartUnionDict]:
    """Returns True if `part` is a value `t_part` can convert."""
    if (
        isinstance(part, str)
        or _is_duck_type_of(part, types.File)
        or _is_duck_type_of(part, types.Part)
    ):
      return True

    if isinstance(part, dict):
      if not part:
        # Empty dict should be considered as Content, not Part.
        return False
      try:
        types.Part.model_validate(part)
        return True
      except pydantic.ValidationError:
        try:
          types.FileData.model_validate(part)
          return True
        except pydantic.ValidationError:
          return False

    # Cheap class-name probe so PIL is only imported for image-like inputs.
    if 'image' in part.__class__.__name__.lower():
      try:
        import PIL.Image

        PIL_Image = PIL.Image.Image
      except ImportError:
        PIL_Image = None

      if PIL_Image is not None and isinstance(part, PIL_Image):
        return True

    return False

  def _is_user_part(part: types.Part) -> bool:
    """A part without a function_call is attributed to the user role."""
    return not part.function_call

  def _are_user_parts(parts: list[types.Part]) -> bool:
    """True if every accumulated part is a user part (vacuously True)."""
    return all(_is_user_part(part) for part in parts)

  def _append_accumulated_parts_as_content(
      result: list[types.Content],
      accumulated_parts: list[types.Part],
  ) -> None:
    """Flushes accumulated parts into a User/ModelContent and clears them."""
    if not accumulated_parts:
      return
    result.append(
        types.UserContent(parts=accumulated_parts)
        if _are_user_parts(accumulated_parts)
        else types.ModelContent(parts=accumulated_parts)
    )
    accumulated_parts[:] = []

  def _handle_current_part(
      result: list[types.Content],
      accumulated_parts: list[types.Part],
      current_part: types.PartUnionDict,
  ) -> None:
    """Accumulates `current_part`, flushing first on a role switch."""
    current_part = t_part(current_part)
    if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
      accumulated_parts.append(current_part)
    else:
      _append_accumulated_parts_as_content(result, accumulated_parts)
      accumulated_parts[:] = [current_part]

  # iterator over contents
  # if content type or content dict, append to result
  # if consecutive part(s),
  #   group consecutive user part(s) to a UserContent
  #   group consecutive model part(s) to a ModelContent
  #   append to result
  # if list, we only accept a list of types.PartUnion
  for content in contents:
    if _is_duck_type_of(content, types.Content) or isinstance(content, list):
      _append_accumulated_parts_as_content(result, accumulated_parts)
      if isinstance(content, list):
        result.append(types.UserContent(parts=content))  # type: ignore[arg-type]
      else:
        result.append(content)  # type: ignore[arg-type]
    elif _is_part(content):
      _handle_current_part(result, accumulated_parts, content)
    elif isinstance(content, dict):
      # Part-shaped dicts were already handled in _is_part, so any dict that
      # reaches here is treated as a Content dict.
      result.append(types.Content.model_validate(content))
    else:
      raise ValueError(f'Unsupported content type: {type(content)}')

  _append_accumulated_parts_as_content(result, accumulated_parts)

  return result


def handle_null_fields(schema: _common.StringDict) -> None:
  """Process null fields in the schema so it is compatible with OpenAPI.

  The OpenAPI spec does not support 'type: 'null' in the schema. This function
  handles this case by adding 'nullable: True' to the null field and removing
  the {'type': 'null'} entry.

  https://swagger.io/docs/specification/v3_0/data-models/data-types/#null

  Example of schema properties before and after handling null fields:
    Before:
      {
        "name": {
          "title": "Name",
          "type": "string"
        },
        "total_area_sq_mi": {
          "anyOf": [
            {
              "type": "integer"
            },
            {
              "type": "null"
            }
          ],
          "default": None,
          "title": "Total Area Sq Mi"
        }
      }

    After:
      {
        "name": {
          "title": "Name",
          "type": "string"
        },
        "total_area_sq_mi": {
          "type": "integer",
          "nullable": true,
          "default": None,
          "title": "Total Area Sq Mi"
        }
      }
  """
  if schema.get('type', None) == 'null':
    schema['nullable'] = True
    del schema['type']
    return

  if 'anyOf' not in schema:
    return

  # Rebuild the list instead of calling .remove() while iterating it: the
  # previous code skipped elements after a removal, and .remove({'type':
  # 'null'}) raised ValueError when the null entry carried extra keys
  # (e.g. {'type': 'null', 'title': ...}).
  non_null_items = [
      item
      for item in schema['anyOf']
      if not ('type' in item and item['type'] == 'null')
  ]
  if len(non_null_items) == len(schema['anyOf']):
    return  # No null entry present; nothing to do.

  schema['nullable'] = True
  if len(non_null_items) == 1:
    # Only one type left after removing null: inline it and drop the anyOf.
    for key, val in non_null_items[0].items():
      schema[key] = val
    del schema['anyOf']
  else:
    schema['anyOf'] = non_null_items


def _raise_for_unsupported_schema_type(origin: Any) -> None:
  """Raises an error if the schema type is unsupported."""
  raise ValueError(f'Unsupported schema type: {origin}')


def _raise_for_unsupported_mldev_properties(
    schema: Any, client: Optional[_api_client.BaseApiClient]
) -> None:
  """Rejects schema features the Gemini (mldev) API does not accept."""
  if not client or client.vertexai:
    return
  has_additional_properties = schema.get('additionalProperties') or schema.get(
      'additional_properties'
  )
  if has_additional_properties:
    raise ValueError('additionalProperties is not supported in the Gemini API.')


def process_schema(
    schema: _common.StringDict,
    client: Optional[_api_client.BaseApiClient],
    defs: Optional[_common.StringDict] = None,
    *,
    order_properties: bool = True,
) -> None:
  """Updates the schema and each sub-schema inplace to be API-compatible.

  - Inlines the $defs.

  Args:
    schema: The JSON schema dict to process in place.
    defs: The '$defs' table for resolving '$ref' entries; when None it is
      popped from `schema` and each definition is pre-processed.
    order_properties: Whether to record a property ordering for objects with
      more than one property (skipped if one is already present).

  Example of a schema before and after (with mldev):
    Before:

    `schema`

    {
        'items': {
            '$ref': '#/$defs/CountryInfo'
        },
        'title': 'Placeholder',
        'type': 'array'
    }


    `defs`

    {
      'CountryInfo': {
        'properties': {
          'continent': {
              'title': 'Continent',
              'type': 'string'
          },
          'gdp': {
              'title': 'Gdp',
              'type': 'integer'}
          },
        }
        'required':['continent', 'gdp'],
        'title': 'CountryInfo',
        'type': 'object'
      }
    }

    After:

    `schema`
     {
        'items': {
          'properties': {
            'continent': {
              'title': 'Continent',
              'type': 'string'
            },
            'gdp': {
              'title': 'Gdp',
              'type': 'integer'
            },
          }
          'required':['continent', 'gdp'],
          'title': 'CountryInfo',
          'type': 'object'
        },
        'type': 'array'
    }
  """
  # Drop the synthetic title (presumably injected by a placeholder wrapper
  # model — see the Placeholder classes in t_schema/_process_enum).
  if schema.get('title') == 'PlaceholderLiteralEnum':
    del schema['title']

  _raise_for_unsupported_mldev_properties(schema, client)

  # Standardize spelling for relevant schema fields.  For example, if a dict is
  # provided directly to response_schema, it may use `any_of` instead of `anyOf.
  # Otherwise, model_json_schema() uses `anyOf`.
  for from_name, to_name in [
      ('additional_properties', 'additionalProperties'),
      ('any_of', 'anyOf'),
      ('prefix_items', 'prefixItems'),
      ('property_ordering', 'propertyOrdering'),
  ]:
    if (value := schema.pop(from_name, None)) is not None:
      schema[to_name] = value

  if defs is None:
    # Top-level call: detach and pre-process the definition table once so
    # recursive calls can resolve '$ref' against it.
    defs = schema.pop('$defs', {})
    for _, sub_schema in defs.items():
      # We can skip the '$ref' check, because JSON schema forbids a '$ref' from
      # directly referencing another '$ref':
      # https://json-schema.org/understanding-json-schema/structuring#recursion
      process_schema(
          sub_schema, client, defs, order_properties=order_properties
      )

  handle_null_fields(schema)

  # After removing null fields, Optional fields with only one possible type
  # will have a $ref key that needs to be flattened
  # For example: {'default': None, 'description': 'Name of the person', 'nullable': True, '$ref': '#/$defs/TestPerson'}
  if (ref := schema.pop('$ref', None)) is not None:
    schema.update(defs[ref.split('defs/')[-1]])

  def _recurse(sub_schema: _common.StringDict) -> _common.StringDict:
    """Returns the processed `sub_schema`, resolving its '$ref' if any."""
    if (ref := sub_schema.pop('$ref', None)) is not None:
      sub_schema = defs[ref.split('defs/')[-1]]
    process_schema(sub_schema, client, defs, order_properties=order_properties)
    return sub_schema

  if (any_of := schema.get('anyOf')) is not None:
    schema['anyOf'] = [_recurse(sub_schema) for sub_schema in any_of]
    return

  # Normalize the type tag (may be an Enum member or lower-case string).
  schema_type = schema.get('type')
  if isinstance(schema_type, Enum):
    schema_type = schema_type.value
  if isinstance(schema_type, str):
    schema_type = schema_type.upper()

  # model_json_schema() returns a schema with a 'const' field when a Literal with one value is provided as a pydantic field
  # For example `genre: Literal['action']` becomes: {'const': 'action', 'title': 'Genre', 'type': 'string'}
  const = schema.get('const')
  if const is not None:
    if schema_type == 'STRING':
      schema['enum'] = [const]
      del schema['const']
    else:
      raise ValueError('Literal values must be strings.')

  if schema_type == 'OBJECT':
    if (properties := schema.get('properties')) is not None:
      for name, sub_schema in list(properties.items()):
        properties[name] = _recurse(sub_schema)
      if (
          len(properties.items()) > 1
          and order_properties
          and 'propertyOrdering' not in schema
      ):
        # NOTE(review): stored under the snake_case key — presumably accepted
        # by Schema.model_validate via field aliasing; confirm.
        schema['property_ordering'] = list(properties.keys())
    if (additional := schema.get('additionalProperties')) is not None:
      # It is legal to set 'additionalProperties' to a bool:
      # https://json-schema.org/understanding-json-schema/reference/object#additionalproperties
      if isinstance(additional, dict):
        schema['additionalProperties'] = _recurse(additional)
  elif schema_type == 'ARRAY':
    if (items := schema.get('items')) is not None:
      schema['items'] = _recurse(items)
    if (prefixes := schema.get('prefixItems')) is not None:
      schema['prefixItems'] = [_recurse(prefix) for prefix in prefixes]


def _process_enum(
    enum: EnumMeta, client: Optional[_api_client.BaseApiClient]
) -> types.Schema:
  """Builds a string Schema for an Enum class.

  Integer-valued enums are re-created with stringified values first.

  Raises:
    TypeError: If any member value is neither str nor int.
  """
  is_integer_enum = False

  for member in enum:  # type: ignore
    if isinstance(member.value, int):
      is_integer_enum = True
    elif not isinstance(member.value, str):
      raise TypeError(
          f'Enum member {member.name} value must be a string or integer, got'
          f' {type(member.value)}'
      )

  enum_to_process = enum
  if is_integer_enum:
    # Re-create the enum with stringified member values so the resulting
    # schema is a string enum.
    str_members = [str(member.value) for member in enum]  # type: ignore
    str_enum = Enum(enum.__name__, str_members, type=str)  # type: ignore
    enum_to_process = str_enum

  # Wrap the enum in a placeholder model so pydantic emits its JSON schema,
  # then extract the processed placeholder property as the enum schema.
  class Placeholder(pydantic.BaseModel):
    placeholder: enum_to_process  # type: ignore[valid-type]

  enum_schema = Placeholder.model_json_schema()
  process_schema(enum_schema, client)
  enum_schema = enum_schema['properties']['placeholder']
  return types.Schema.model_validate(enum_schema)


def _is_type_dict_str_any(
    origin: Union[types.SchemaUnionDict, Any],
) -> TypeGuard[_common.StringDict]:
  """Narrows `origin` to dict[str, Any] (a dict whose keys are all strings)."""
  if not isinstance(origin, dict):
    return False
  return all(isinstance(key, str) for key in origin)


def t_schema(
    client: Optional[_api_client.BaseApiClient],
    origin: Union[types.SchemaUnionDict, Any],
) -> Optional[types.Schema]:
  """Converts a schema-like value into an API-compatible `types.Schema`.

  Accepts a dict schema, an Enum class, a Schema instance, a pydantic model
  class, or a plain/generic/union type annotation. Falsy input yields None.

  Raises:
    ValueError: If `origin` cannot be converted to a schema.
  """
  if not origin:
    return None
  if isinstance(origin, dict) and _is_type_dict_str_any(origin):
    process_schema(origin, client)
    return types.Schema.model_validate(origin)
  if isinstance(origin, EnumMeta):
    return _process_enum(origin, client)
  if _is_duck_type_of(origin, types.Schema):
    if dict(origin) == dict(types.Schema()):  # type: ignore [arg-type]
      # response_schema value was coerced to an empty Schema instance because
      # it did not adhere to the Schema field annotation
      _raise_for_unsupported_schema_type(origin)
    schema = origin.model_dump(exclude_unset=True)  # type: ignore[union-attr]
    process_schema(schema, client)
    return types.Schema.model_validate(schema)

  if (
      # in Python 3.9 Generic alias list[int] counts as a type,
      # and breaks issubclass because it's not a class.
      not isinstance(origin, GenericAlias)
      and isinstance(origin, type)
      and issubclass(origin, pydantic.BaseModel)
  ):
    schema = origin.model_json_schema()
    process_schema(schema, client)
    return types.Schema.model_validate(schema)
  elif (
      isinstance(origin, GenericAlias)
      or isinstance(origin, type)
      or isinstance(origin, VersionedUnionType)
      or typing.get_origin(origin) in _UNION_TYPES
  ):

    # Wrap the annotation in a placeholder model so pydantic produces its
    # JSON schema, then extract the placeholder property's schema.
    class Placeholder(pydantic.BaseModel):
      placeholder: origin  # type: ignore[valid-type]

    schema = Placeholder.model_json_schema()
    process_schema(schema, client)
    schema = schema['properties']['placeholder']
    return types.Schema.model_validate(schema)

  raise ValueError(f'Unsupported schema type: {origin}')


def t_speech_config(
    origin: Union[types.SpeechConfigUnionDict, Any],
) -> Optional[types.SpeechConfig]:
  """Coerces a SpeechConfig, voice-name string, or dict to a SpeechConfig.

  Raises:
    ValueError: If `origin` is an unsupported type.
  """
  if not origin:
    return None
  if _is_duck_type_of(origin, types.SpeechConfig):
    return origin  # type: ignore[return-value]
  if isinstance(origin, str):
    # A bare string is shorthand for a prebuilt voice name.
    voice = types.PrebuiltVoiceConfig(voice_name=origin)
    return types.SpeechConfig(
        voice_config=types.VoiceConfig(prebuilt_voice_config=voice)
    )
  if isinstance(origin, dict):
    return types.SpeechConfig.model_validate(origin)

  raise ValueError(f'Unsupported speechConfig type: {type(origin)}')


def t_live_speech_config(
    origin: types.SpeechConfigOrDict,
) -> Optional[types.SpeechConfig]:
  """Validates and coerces a speech config for the live API.

  Args:
    origin: A SpeechConfig-like object or its dict form.

  Returns:
    The validated SpeechConfig.

  Raises:
    ValueError: If `origin` is an unsupported type, or if it sets
      multi_speaker_voice_config (not supported by the live API).
  """
  if _is_duck_type_of(origin, types.SpeechConfig):
    speech_config = origin
  elif isinstance(origin, dict):
    speech_config = types.SpeechConfig.model_validate(origin)
  else:
    # Previously an unsupported type crashed later with UnboundLocalError on
    # `speech_config`; fail fast with a clear message instead (mirrors
    # t_speech_config).
    raise ValueError(f'Unsupported speechConfig type: {type(origin)}')

  if speech_config.multi_speaker_voice_config is not None:  # type: ignore[union-attr]
    raise ValueError(
        'multi_speaker_voice_config is not supported in the live API.'
    )

  return speech_config  # type: ignore[return-value]


def t_tool(
    client: _api_client.BaseApiClient, origin: Any
) -> Optional[Union[types.Tool, Any]]:
  """Converts a callable, MCP tool, or dict into a `types.Tool`.

  Falsy input yields None; unrecognized values pass through unchanged.
  """
  if not origin:
    return None
  if inspect.isfunction(origin) or inspect.ismethod(origin):
    declaration = types.FunctionDeclaration.from_callable(
        client=client, callable=origin
    )
    return types.Tool(function_declarations=[declaration])
  if McpTool is not None and _is_duck_type_of(origin, McpTool):
    return mcp_to_gemini_tool(origin)
  if isinstance(origin, dict):
    return types.Tool.model_validate(origin)
  return origin


def t_tools(
    client: _api_client.BaseApiClient, origin: list[Any]
) -> list[types.Tool]:
  """Transforms tools, merging all function declarations into one Tool."""
  if not origin:
    return []
  merged_function_tool = types.Tool(function_declarations=[])
  tools = []
  for raw_tool in origin:
    tool = t_tool(client, raw_tool)
    if tool is None:
      continue
    # Function declarations from separate inputs are collected into a single
    # Tool; everything else is appended as-is.
    if (
        tool.function_declarations
        and merged_function_tool.function_declarations is not None
    ):
      merged_function_tool.function_declarations += tool.function_declarations
    else:
      tools.append(tool)
  if merged_function_tool.function_declarations:
    tools.append(merged_function_tool)
  return tools


def t_cached_content_name(client: _api_client.BaseApiClient, name: str) -> str:
  """Resolves `name` to a fully-qualified cachedContents resource name."""
  collection = 'cachedContents'
  return _resource_name(client, name, collection_identifier=collection)


def t_batch_job_source(
    client: _api_client.BaseApiClient,
    src: types.BatchJobSourceUnionDict,
) -> types.BatchJobSource:
  """Normalizes a batch job source and validates it for the target backend.

  Accepts a `types.BatchJobSource` (or its dict form), a list of inlined
  requests, or a string URI starting with 'gs://', 'bq://', or 'files/'.

  Raises:
    ValueError: If the source is unsupported, or sets the wrong combination
      of fields for the backend (Vertex AI vs. Gemini API).
  """
  if isinstance(src, dict):
    src = types.BatchJobSource(**src)
  if _is_duck_type_of(src, types.BatchJobSource):
    vertex_count = sum(
        [src.gcs_uri is not None, src.bigquery_uri is not None]  # type: ignore[union-attr]
    )
    mldev_count = sum([
        src.inlined_requests is not None,  # type: ignore[union-attr]
        src.file_name is not None,  # type: ignore[union-attr]
    ])
    if client.vertexai:
      # Vertex AI accepts only GCS or BigQuery sources, exactly one of them.
      if mldev_count or vertex_count != 1:
        raise ValueError(
            'Exactly one of `gcs_uri` or `bigquery_uri` must be set, other '
            'sources are not supported in Vertex AI.'
        )
    elif vertex_count or mldev_count != 1:
      raise ValueError(
          'Exactly one of `inlined_requests`, `file_name`, '
          '`inlined_embed_content_requests`, or `embed_content_file_name` '
          'must be set, other sources are not supported in Gemini API.'
      )
    return src  # type: ignore[return-value]

  if isinstance(src, list):
    return types.BatchJobSource(inlined_requests=src)
  if isinstance(src, str):
    if src.startswith('gs://'):
      return types.BatchJobSource(format='jsonl', gcs_uri=[src])
    if src.startswith('bq://'):
      return types.BatchJobSource(format='bigquery', bigquery_uri=src)
    if src.startswith('files/'):
      return types.BatchJobSource(file_name=src)

  raise ValueError(f'Unsupported source: {src}')


def t_embedding_batch_job_source(
    client: _api_client.BaseApiClient,
    src: types.EmbeddingsBatchJobSourceOrDict,
) -> types.EmbeddingsBatchJobSource:
  """Validates an embeddings batch job source for the Gemini API.

  Raises:
    ValueError: If `src` is of an unsupported type or does not set exactly
      one of `inlined_requests` / `file_name`.
  """
  if isinstance(src, dict):
    src = types.EmbeddingsBatchJobSource(**src)

  if not _is_duck_type_of(src, types.EmbeddingsBatchJobSource):
    raise ValueError(f'Unsupported source type: {type(src)}')

  source_count = sum([
      src.inlined_requests is not None,
      src.file_name is not None,
  ])
  if source_count != 1:
    raise ValueError(
        'Exactly one of `inlined_requests`, `file_name`, '
        '`inlined_embed_content_requests`, or `embed_content_file_name` '
        'must be set, other sources are not supported in Gemini API.'
    )
  return src


def t_batch_job_destination(
    dest: Union[str, types.BatchJobDestinationOrDict],
) -> types.BatchJobDestination:
  """Normalizes a batch job destination (dict, URI string, or model object).

  String destinations must start with 'gs://' (JSONL) or 'bq://' (BigQuery).

  Raises:
    ValueError: If the destination cannot be recognized.
  """
  if isinstance(dest, dict):
    return types.BatchJobDestination(**dest)
  if isinstance(dest, str):
    if dest.startswith('gs://'):
      return types.BatchJobDestination(format='jsonl', gcs_uri=dest)
    if dest.startswith('bq://'):
      return types.BatchJobDestination(format='bigquery', bigquery_uri=dest)
    raise ValueError(f'Unsupported destination: {dest}')
  if _is_duck_type_of(dest, types.BatchJobDestination):
    return dest
  raise ValueError(f'Unsupported destination: {dest}')


def t_recv_batch_job_destination(dest: dict[str, Any]) -> dict[str, Any]:
  """Renames `inlinedResponses` when the payload carries embedding responses.

  Mutates `dest` in place: if any inlined response contains an 'embedding'
  key, the top-level 'inlinedResponses' entry is moved to
  'inlinedEmbedContentResponses'. Returns `dest` either way.
  """
  responses = dest.get('inlinedResponses', {}).get('inlinedResponses', [])
  if not responses:
    return dest
  has_embedding = any(
      inner and 'embedding' in inner
      for inner in (entry.get('response', {}) for entry in responses)
  )
  if has_embedding:
    dest['inlinedEmbedContentResponses'] = dest.pop('inlinedResponses')
  return dest


def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
  """Extracts the bare batch job id from a fully-qualified or short name.

  Gemini API names must look like 'batches/<id>'. Vertex AI names must be
  either a full 'projects/.../locations/.../batchPredictionJobs/<id>'
  resource name or a bare numeric id.

  Raises:
    ValueError: If `name` does not match the expected format.
  """
  if not client.vertexai:
    if re.match(r'batches/[^/]+$', name):
      return name.split('/')[-1]
    raise ValueError(f'Invalid batch job name: {name}.')

  vertex_pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'
  if re.match(vertex_pattern, name):
    return name.split('/')[-1]
  if name.isdigit():
    return name
  raise ValueError(f'Invalid batch job name: {name}.')


def t_job_state(state: str) -> str:
  """Maps Gemini API BATCH_STATE_* values to Vertex-style JOB_STATE_* values.

  Unrecognized states are returned unchanged.
  """
  batch_to_job_state = {
      'BATCH_STATE_UNSPECIFIED': 'JOB_STATE_UNSPECIFIED',
      'BATCH_STATE_PENDING': 'JOB_STATE_PENDING',
      'BATCH_STATE_RUNNING': 'JOB_STATE_RUNNING',
      'BATCH_STATE_SUCCEEDED': 'JOB_STATE_SUCCEEDED',
      'BATCH_STATE_FAILED': 'JOB_STATE_FAILED',
      'BATCH_STATE_CANCELLED': 'JOB_STATE_CANCELLED',
      'BATCH_STATE_EXPIRED': 'JOB_STATE_EXPIRED',
  }
  return batch_to_job_state.get(state, state)


# Long-running-operation polling parameters used by t_resolve_operation:
# start with a 1s delay between polls, back off by 1.5x per iteration up to
# a 20s ceiling, and give up after 900s of accumulated waiting.
LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
LRO_POLLING_TIMEOUT_SECONDS = 900.0
LRO_POLLING_MULTIPLIER = 1.5


def t_resolve_operation(
    api_client: _api_client.BaseApiClient, struct: _common.StringDict
) -> Any:
  """Resolves an operation struct, polling long-running operations to completion.

  If `struct` names an operation (its 'name' contains '/operations/'), polls
  it with exponential backoff until 'done' is truthy and returns its
  'response' field. Otherwise returns `struct` unchanged.

  Args:
    api_client: Client used to GET the operation by name.
    struct: The response dict that may describe an operation.

  Raises:
    RuntimeError: If polling exceeds LRO_POLLING_TIMEOUT_SECONDS, or the
      finished operation carries an 'error'.
  """
  if (name := struct.get('name')) and '/operations/' in name:
    operation: _common.StringDict = struct
    total_seconds = 0.0
    delay_seconds = LRO_POLLING_INITIAL_DELAY_SECONDS
    while operation.get('done') != True:
      if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
        raise RuntimeError(f'Operation {name} timed out.\n{operation}')
      # TODO(b/374433890): Replace with LRO module once it's available.
      operation = api_client.request(  # type: ignore[assignment]
          http_method='GET', path=name, request_dict={}
      )
      time.sleep(delay_seconds)
      # Bug fix: accumulate the time actually slept. This was previously
      # `total_seconds += total_seconds`, which kept the total at 0.0 forever
      # and made the timeout check above unreachable (infinite polling for an
      # operation that never completes).
      total_seconds += delay_seconds
      # Exponential backoff, capped at the maximum delay.
      delay_seconds = min(
          delay_seconds * LRO_POLLING_MULTIPLIER,
          LRO_POLLING_MAXIMUM_DELAY_SECONDS,
      )
    if error := operation.get('error'):
      raise RuntimeError(
          f'Operation {name} failed with error: {error}.\n{operation}'
      )
    return operation.get('response')
  else:
    return struct


def t_file_name(
    name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]],
) -> str:
  """Extracts a bare file name, stripping any 'files/' prefix or URL wrapper.

  Accepts a plain name/URI string or a File, Video, or GeneratedVideo object
  from which the name/URI is read.

  Raises:
    ValueError: If no name can be derived or extracted from the input.
  """
  # First unwrap supported object types down to a string (or None).
  if _is_duck_type_of(name, types.File):
    name = name.name  # type: ignore[union-attr]
  elif _is_duck_type_of(name, types.Video):
    name = name.uri  # type: ignore[union-attr]
  elif _is_duck_type_of(name, types.GeneratedVideo):
    video = name.video  # type: ignore[union-attr]
    name = video.uri if video is not None else None

  if name is None:
    raise ValueError('File name is required.')
  if not isinstance(name, str):
    raise ValueError(
        f'Could not convert object of type `{type(name)}` to a file name.'
    )

  # Strip the URL or resource-name prefix since it's added to the url path.
  if name.startswith('https://'):
    suffix = name.split('files/')[1]
    match = re.match('[a-z0-9]+', suffix)
    if match is None:
      raise ValueError(f'Could not extract file name from URI: {name}')
    return match.group(0)
  if name.startswith('files/'):
    return name.split('files/')[1]
  return name


def t_tuning_job_status(status: str) -> Union[types.JobState, str]:
  """Maps legacy tuning-job status strings onto `types.JobState` members.

  Statuses that already match a JobState value are returned as that enum
  member; anything else is passed through unchanged.
  """
  legacy_states = {
      'STATE_UNSPECIFIED': types.JobState.JOB_STATE_UNSPECIFIED,
      'CREATING': types.JobState.JOB_STATE_RUNNING,
      'ACTIVE': types.JobState.JOB_STATE_SUCCEEDED,
      'FAILED': types.JobState.JOB_STATE_FAILED,
  }
  if status in legacy_states:
    return legacy_states[status]
  for state in types.JobState:
    if str(state.value) == status:
      return state
  return status


def t_content_strict(content: types.ContentOrDict) -> types.Content:
  """Strictly converts a Content or dict to `types.Content`, else raises."""
  if isinstance(content, dict):
    return types.Content.model_validate(content)
  if _is_duck_type_of(content, types.Content):
    return content
  raise ValueError(
      f'Could not convert input (type "{type(content)}") to `types.Content`'
  )


def t_contents_strict(
    contents: Union[Sequence[types.ContentOrDict], types.ContentOrDict],
) -> list[types.Content]:
  """Strictly converts one content or a sequence of contents to a list."""
  items = contents if isinstance(contents, Sequence) else [contents]
  return [t_content_strict(item) for item in items]


def t_client_content(
    turns: Optional[
        Union[Sequence[types.ContentOrDict], types.ContentOrDict]
    ] = None,
    turn_complete: bool = True,
) -> types.LiveClientContent:
  """Builds a `types.LiveClientContent` from optional conversation turns.

  Raises:
    ValueError: If `turns` cannot be converted to contents.
  """
  if turns is None:
    return types.LiveClientContent(turn_complete=turn_complete)

  try:
    converted = t_contents_strict(contents=turns)
    return types.LiveClientContent(
        turns=converted, turn_complete=turn_complete
    )
  except Exception as e:
    raise ValueError(
        f'Could not convert input (type "{type(turns)}") to '
        '`types.LiveClientContent`'
    ) from e


def t_tool_response(
    input: Union[
        types.FunctionResponseOrDict,
        Sequence[types.FunctionResponseOrDict],
    ],
) -> types.LiveClientToolResponse:
  """Builds a `types.LiveClientToolResponse` from function response(s).

  Raises:
    ValueError: If `input` is empty or cannot be converted.
  """
  if not input:
    raise ValueError(f'A tool response is required, got: \n{input}')

  try:
    responses = t_function_responses(function_responses=input)
    return types.LiveClientToolResponse(function_responses=responses)
  except Exception as e:
    raise ValueError(
        f'Could not convert input (type "{type(input)}") to '
        '`types.LiveClientToolResponse`'
    ) from e


def t_metrics(
    metrics: list[types.MetricSubclass]
) -> list[dict[str, Any]]:
  """Prepares the metric payload for an evaluation request.

  Args:
    metrics: A list of resolved metric objects.

  Returns:
    A list of metric payload dicts, one per input metric, each carrying the
    aggregation settings and the spec matching the metric's type.

  Raises:
    ValueError: If a metric is neither a known computation-based metric
      (exact_match, bleu, rouge*) nor a model-based metric with a
      `prompt_template`.
  """
  metrics_payload = []

  for metric in metrics:
    metric_payload_item: dict[str, Any] = {}
    metric_payload_item['aggregation_metrics'] = [
        'AVERAGE',
        'STANDARD_DEVIATION',
    ]

    metric_name = getv(metric, ['name']).lower()

    if metric_name == 'exact_match':
      metric_payload_item['exact_match_spec'] = {}
    elif metric_name == 'bleu':
      metric_payload_item['bleu_spec'] = {}
    elif metric_name.startswith('rouge'):
      # Drop underscores, e.g. 'rouge_1' -> 'rouge1'.
      rouge_type = metric_name.replace('_', '')
      metric_payload_item['rouge_spec'] = {'rouge_type': rouge_type}
    elif hasattr(metric, 'prompt_template') and metric.prompt_template:
      # Model-based (judge) metric: build a pointwise metric spec.
      pointwise_spec = {'metric_prompt_template': metric.prompt_template}
      system_instruction = getv(metric, ['judge_model_system_instruction'])
      if system_instruction:
        pointwise_spec['system_instruction'] = system_instruction
      return_raw_output = getv(metric, ['return_raw_output'])
      if return_raw_output:
        pointwise_spec['custom_output_format_config'] = {  # type: ignore[assignment]
            'return_raw_output': return_raw_output
        }
      metric_payload_item['pointwise_metric_spec'] = pointwise_spec
    else:
      raise ValueError(
          'Unsupported metric type or invalid metric name:' f' {metric_name}'
      )
    metrics_payload.append(metric_payload_item)
  return metrics_payload
